###################################################################################################
# ATTENTION! This test will most likely fail if only TensorRT 6.0.1 is installed.
# That's because 6.0.1 ships with an older version of the ONNX parser that does not support
# some required features. To make it work, build the newer parser from
# https://github.com/onnx/onnx-tensorrt -- clone it and do something like this:
#
# ~/pt/third_party/onnx-tensorrt$ mkdir build/
# ~/pt/third_party/onnx-tensorrt$ cd build/
# ~/pt/third_party/onnx-tensorrt/build$ cmake ..
# ~/pt/third_party/onnx-tensorrt/build$ make
# ~/pt/third_party/onnx-tensorrt/build$ sudo cp libnvonnxparser.so.6.0.1 /usr/lib/x86_64-linux-gnu
#
# This note applies to the 6.0.1 release only (written September 18th, 2019).
###################################################################################################

import os
import tempfile
import unittest

from PIL import Image
import numpy as np
import torch
import torchvision.models as models

import pycuda.driver as cuda
# This import causes pycuda to automatically manage CUDA context creation and cleanup.
import pycuda.autoinit

import tensorrt as trt
# Shared TensorRT logger passed to the builder and the ONNX parser below;
# constructed with WARNING as its minimum reported severity.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def allocate_buffers(engine):
    """Allocate host and device buffers for a single-input, single-output engine.

    Binding 0 is treated as the input and binding 1 as the output. Both host
    buffers are flat float32 arrays sized from the engine's binding shapes.

    Returns:
        Tuple ``(h_input, d_input, h_output, d_output, stream)``: page-locked host
        arrays, matching device allocations of the same byte size, and a new
        CUDA stream for asynchronous copies and execution.
    """
    float_dtype = trt.nptype(trt.float32)

    def _host_buffer(binding_index):
        # Page-locked host memory, as required by pycuda's async memcpy calls.
        element_count = trt.volume(engine.get_binding_shape(binding_index))
        return cuda.pagelocked_empty(element_count, dtype=float_dtype)

    host_in = _host_buffer(0)
    host_out = _host_buffer(1)
    dev_in, dev_out = (cuda.mem_alloc(buf.nbytes) for buf in (host_in, host_out))
    return host_in, dev_in, host_out, dev_out, cuda.Stream()

def load_normalized_test_case(input_shape, test_image, pagelocked_buffer, normalization_hint):
    """Load ``test_image``, resize and normalize it, and copy it into ``pagelocked_buffer``.

    Args:
        input_shape: ``(channels, height, width)``; the image is resized to
            ``width x height`` and laid out channel-first (CHW), then flattened.
        test_image: path to the image file.
        pagelocked_buffer: flat host buffer that receives the normalized float32
            data; its size must equal ``channels * height * width``.
        normalization_hint: ``0`` maps pixels to ``(x / 255 - 0.45) / 0.225``;
            ``1`` maps them to ``x / 256 - 0.5``.

    Returns:
        ``test_image`` unchanged, so callers can compare the prediction with the file name.

    Raises:
        ValueError: if ``normalization_hint`` is neither 0 nor 1.
    """
    # Previously an unknown hint silently produced None and np.copyto failed obscurely.
    if normalization_hint not in (0, 1):
        raise ValueError("Unknown normalization_hint: {}".format(normalization_hint))

    _channels, h, w = input_shape
    # Close the file handle deterministically. Convert to RGB so grayscale/RGBA
    # inputs still yield an HxWx3 array for the CHW transpose. Image.LANCZOS is
    # the filter formerly aliased as Image.ANTIALIAS (removed in Pillow 10).
    with Image.open(test_image) as image:
        resized = image.convert('RGB').resize((w, h), Image.LANCZOS)
    image_arr = np.asarray(resized).transpose([2, 0, 1]) \
        .astype(trt.nptype(trt.float32)).ravel()

    if normalization_hint == 0:
        normalized = (image_arr / 255.0 - 0.45) / 0.225
    else:
        normalized = image_arr / 256.0 - 0.5
    np.copyto(pagelocked_buffer, normalized)
    return test_image

class Test_PT_ONNX_TRT(unittest.TestCase):
    """End-to-end check: torchvision model -> ONNX export -> TensorRT engine -> inference.

    Each ``test_*`` method exports a pretrained torchvision model to ONNX, builds a
    TensorRT engine from the export, classifies the images found in ``data/`` and
    tolerates at most one misclassification.
    """

    def __enter__(self):
        # NOTE(review): there is no matching __exit__, so this class cannot actually
        # be used in a ``with`` statement; kept only for backward compatibility.
        return self

    def setUp(self):
        """Resolve absolute test image paths and load class labels from ``data/``.

        Raises:
            FileNotFoundError: if any of the expected test images is missing.
        """
        data_path = os.path.join(os.path.dirname(__file__), 'data')
        image_names = ["binoculars.jpeg", "reflex_camera.jpeg", "tabby_tiger_cat.jpg"]
        self.image_files = [os.path.abspath(os.path.join(data_path, name))
                            for name in image_names]
        for image_file in self.image_files:
            if not os.path.exists(image_file):
                raise FileNotFoundError(image_file + " does not exist.")
        with open(os.path.abspath(os.path.join(data_path, "class_labels.txt")), 'r') as f:
            self.labels = f.read().split('\n')

    def build_engine_onnx(self, model_file):
        """Parse the ONNX file ``model_file`` and build a TensorRT engine from it.

        Fails the current test via ``self.fail`` if the parser reports an error
        (all parser errors are included in the message) or if TensorRT returns
        no engine.
        """
        # The ONNX parser needs an explicit-batch network; this equals the former flags=1.
        explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network(flags=explicit_batch) as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser:
            builder_config = builder.create_builder_config()
            # 1 << 33 bytes = 8 GiB upper bound on builder scratch memory.
            builder_config.max_workspace_size = 1 << 33
            with open(model_file, 'rb') as model:
                parsed = parser.parse(model.read())
            if not parsed:
                # Collect every error before failing; self.fail raises, so calling it
                # inside the loop would only ever report the first one.
                errors = ["ERROR: {}".format(parser.get_error(i))
                          for i in range(parser.num_errors)]
                self.fail("Failed to parse {}:\n{}".format(
                    model_file, "\n".join(errors) or "ERROR: unknown parser error"))
            engine = builder.build_engine(network, builder_config)
            if engine is None:
                self.fail("TensorRT failed to build an engine from {}".format(model_file))
            return engine

    def _test_model(self, model_name, input_shape=(3, 224, 224), normalization_hint=0):
        """Export ``torchvision.models.<model_name>`` to ONNX, run it in TensorRT, check labels.

        Args:
            model_name: name of a model constructor in ``torchvision.models``.
            input_shape: ``(channels, height, width)`` of one input image.
            normalization_hint: forwarded to ``load_normalized_test_case``.
        """
        model = getattr(models, model_name)(pretrained=True)
        # Export inference behaviour explicitly (dropout off, batch-norm running stats).
        model.eval()

        dummy_input = (torch.randn((1,) + tuple(input_shape)),)

        # Keep the ONNX file in a scratch directory so runs do not leave artifacts
        # in the current working directory; the built engine lives in memory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            onnx_name = os.path.join(tmp_dir, model_name + ".onnx")
            torch.onnx.export(model,
                              dummy_input,
                              onnx_name,
                              input_names=[],
                              output_names=[],
                              verbose=False,
                              export_params=True,
                              opset_version=9)
            engine = self.build_engine_onnx(onnx_name)

        with engine:
            h_input, d_input, h_output, d_output, stream = allocate_buffers(engine)
            with engine.create_execution_context() as context:
                misclassified = []
                for image_file in self.image_files:
                    test_case = load_normalized_test_case(input_shape, image_file,
                                                          h_input, normalization_hint)
                    cuda.memcpy_htod_async(d_input, h_input, stream)
                    context.execute_async_v2(bindings=[int(d_input), int(d_output)],
                                             stream_handle=stream.handle)
                    cuda.memcpy_dtoh_async(h_output, d_output, stream)
                    stream.synchronize()

                    pred = self.labels[np.argmax(h_output)]
                    # The expected label is encoded in the image file name.
                    expected = os.path.splitext(os.path.basename(test_case))[0]
                    if "_".join(pred.split()) not in expected:
                        misclassified.append(
                            "{} -> {}".format(os.path.basename(test_case), pred))
                self.assertLessEqual(
                    len(misclassified), 1,
                    "Too many recognition errors: {}".format(", ".join(misclassified)))

    def test_alexnet(self):
        self._test_model("alexnet", (3, 227, 227))

    def test_resnet18(self):
        self._test_model("resnet18")
    def test_resnet34(self):
        self._test_model("resnet34")
    def test_resnet50(self):
        self._test_model("resnet50")
    def test_resnet101(self):
        self._test_model("resnet101")
    @unittest.skip("Takes 2m")
    def test_resnet152(self):
        self._test_model("resnet152")

    def test_resnet50_2(self):
        self._test_model("wide_resnet50_2")
    @unittest.skip("Takes 2m")
    def test_resnet101_2(self):
        self._test_model("wide_resnet101_2")

    def test_squeezenet1_0(self):
        self._test_model("squeezenet1_0")
    def test_squeezenet1_1(self):
        self._test_model("squeezenet1_1")

    def test_googlenet(self):
        self._test_model("googlenet")
    def test_inception_v3(self):
        self._test_model("inception_v3")

    def test_mnasnet0_5(self):
        self._test_model("mnasnet0_5", normalization_hint = 1)
    def test_mnasnet1_0(self):
        self._test_model("mnasnet1_0", normalization_hint = 1)

    def test_mobilenet_v2(self):
        self._test_model("mobilenet_v2", normalization_hint = 1)

    def test_shufflenet_v2_x0_5(self):
        self._test_model("shufflenet_v2_x0_5")
    def test_shufflenet_v2_x1_0(self):
        self._test_model("shufflenet_v2_x1_0")

    def test_vgg11(self):
        self._test_model("vgg11")
    def test_vgg11_bn(self):
        self._test_model("vgg11_bn")
    def test_vgg13(self):
        self._test_model("vgg13")
    def test_vgg13_bn(self):
        self._test_model("vgg13_bn")
    def test_vgg16(self):
        self._test_model("vgg16")
    def test_vgg16_bn(self):
        self._test_model("vgg16_bn")
    def test_vgg19(self):
        self._test_model("vgg19")
    def test_vgg19_bn(self):
        self._test_model("vgg19_bn")

    @unittest.skip("Takes 13m")
    def test_densenet121(self):
        self._test_model("densenet121")
    @unittest.skip("Takes 25m")
    def test_densenet161(self):
        self._test_model("densenet161")
    @unittest.skip("Takes 27m")
    def test_densenet169(self):
        self._test_model("densenet169")
    @unittest.skip("Takes 44m")
    def test_densenet201(self):
        self._test_model("densenet201")

# Allow running this file directly as a script, not only through a test runner.
if __name__ == '__main__':
    unittest.main()
